﻿/*******************************************************************************
* Copyright (C) sykjwh.cn
* 
* Author: liuxiang
* Create Date: 2019/01/16
* Description: Automated building by liuxiang20041986@qq.com 
* http://www.sykjwh.cn/
*********************************************************************************/

using System;
using System.Collections.Generic;

namespace Sykj.Repository
{
    /// <summary>
    /// Unit of work: queues add, update and delete operations and runs them together
    /// inside a single database transaction when <see cref="Commit"/> is called.
    /// </summary>
    /// <remarks>
    /// Per the original author's note, this class is not suitable when the caller
    /// needs the auto-increment ID returned from an insert.
    /// </remarks>
    public class UnitOfWork : IUnitOfWork
    {
        // Pending operations, keyed by the entity instance they act on. Each value
        // applies the change to the context, saves, and returns the SaveChanges() result.
        private readonly Dictionary<object, Func<int>> _addEntities;
        private readonly Dictionary<object, Func<int>> _updateEntities;
        private readonly Dictionary<object, Func<int>> _deleteEntities;

        // Data access context that every queued operation and the transaction run against.
        private readonly SyDbContext _dbContext;

        /// <summary>
        /// Creates a unit of work bound to the given data context, with empty operation queues.
        /// </summary>
        /// <param name="dbContext">Context used to apply the queued changes and to open the transaction.</param>
        public UnitOfWork(SyDbContext dbContext)
        {
            _dbContext = dbContext;

            _addEntities = new Dictionary<object, Func<int>>();
            _updateEntities = new Dictionary<object, Func<int>>();
            _deleteEntities = new Dictionary<object, Func<int>>();
        }

        /// <summary>
        /// Queues an entity to be added. Nothing touches the context until <see cref="Commit"/> runs.
        /// </summary>
        /// <typeparam name="T">Entity type.</typeparam>
        /// <param name="entity">Entity to add.</param>
        /// <exception cref="ArgumentNullException"><paramref name="entity"/> is null (it is used as a dictionary key).</exception>
        /// <exception cref="ArgumentException">An equal entity is already queued for adding.</exception>
        public void Add<T>(T entity) where T : class
        {
            _addEntities.Add(entity, () =>
            {
                _dbContext.Add(entity);
                return _dbContext.SaveChanges();
            });
        }

        /// <summary>
        /// Queues an entity to be updated. Nothing touches the context until <see cref="Commit"/> runs.
        /// </summary>
        /// <typeparam name="T">Entity type.</typeparam>
        /// <param name="entity">Entity to update.</param>
        /// <exception cref="ArgumentNullException"><paramref name="entity"/> is null (it is used as a dictionary key).</exception>
        /// <exception cref="ArgumentException">An equal entity is already queued for updating.</exception>
        public void Update<T>(T entity) where T : class
        {
            _updateEntities.Add(entity, () =>
            {
                _dbContext.Update(entity);
                return _dbContext.SaveChanges();
            });
        }

        /// <summary>
        /// Queues an entity to be deleted. Nothing touches the context until <see cref="Commit"/> runs.
        /// </summary>
        /// <typeparam name="T">Entity type.</typeparam>
        /// <param name="entity">Entity to delete.</param>
        /// <exception cref="ArgumentNullException"><paramref name="entity"/> is null (it is used as a dictionary key).</exception>
        /// <exception cref="ArgumentException">An equal entity is already queued for deletion.</exception>
        public void Delete<T>(T entity) where T : class
        {
            _deleteEntities.Add(entity, () =>
            {
                _dbContext.Remove(entity);
                return _dbContext.SaveChanges();
            });
        }

        /// <summary>
        /// Runs every queued operation inside one transaction: deletes first, then updates, then adds.
        /// </summary>
        /// <returns>The sum of the values returned by each operation's <c>SaveChanges()</c> call.</returns>
        /// <remarks>
        /// On success the transaction is committed and all three queues are cleared.
        /// On failure the transaction is rolled back and the original exception is rethrown;
        /// the queues are left untouched in that case.
        /// </remarks>
        public int Commit()
        {
            int count = 0;

            // Open the transaction that wraps all queued operations.
            using (var tran = _dbContext.Database.BeginTransaction())
            {
                try
                {
                    foreach (var delete in _deleteEntities.Values)
                    {
                        count += delete();
                    }

                    foreach (var update in _updateEntities.Values)
                    {
                        count += update();
                    }

                    foreach (var add in _addEntities.Values)
                    {
                        count += add();
                    }

                    tran.Commit();

                    _deleteEntities.Clear();
                    _updateEntities.Clear();
                    _addEntities.Clear();
                }
                catch (Exception)
                {
                    tran.Rollback();
                    // Bare rethrow keeps the original stack trace ("throw ex;" would reset it).
                    throw;
                }
            }

            return count;
        }
    }
}
